# -*- coding: utf-8 -*-
"""
Replication Code for Choi & Jung (2025) PSQ R&R
Topic: Immigration Populism Rhetoric Analysis using ParlaSent

This script performs the following:
1. Data preprocessing and filtering based on topic relevance.
2. Sentiment analysis using the 'classla/xlm-r-parlasent' XLM-RoBERTa model.
3. Statistical aggregation and visualization of sentiment trends by party and ideology.

Dependencies: pandas, numpy, torch, transformers, matplotlib, seaborn, scipy, tqdm
"""

import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import norm
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

# -----------------------------------------------------------------------------
# 1. Configuration & Constants
# -----------------------------------------------------------------------------
# Set random seed for reproducibility (seeds both the torch and the NumPy global RNGs)
SEED = 2025
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths (Users should adjust DATA_DIR to their local environment)
DATA_DIR = "./Data"  # Relative path recommended for replication packages
# Input corpus read by run_analysis().
INPUT_FILE = os.path.join(DATA_DIR, "sentiment_analysis_python.csv")
# Scored output. run_analysis() also treats this file as a cache: if it already
# exists, inference is skipped and the file is loaded instead.
OUTPUT_FILE = os.path.join(DATA_DIR, "parlasent_sentiment_results_final.csv")
# Directory that save_plot() writes figures into.
FIGURE_DIR = os.path.join(DATA_DIR, "figures")

# Model Configuration
MODEL_NAME = "classla/xlm-r-parlasent"
# Chunk length (in tokens) used when splitting long texts for inference.
MAX_TOKENS = 512
# Device index handed to SentimentAnalyzer: 0 selects "cuda:0", -1 selects the CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
# Registers Series.progress_apply, which run_analysis() uses for the inference loop.
tqdm.pandas()

# Plotting Style
# NOTE(review): sns.set_theme() applies its own font setting (its `font`
# parameter defaults to "sans-serif"), so it presumably overrides the
# 'font.family' value assigned on the line before it. If Arial must be forced,
# pass font="Arial" to set_theme() or move the rcParams line below it -- confirm
# against the installed seaborn version before changing.
plt.rcParams['font.family'] = 'Arial'
sns.set_theme(style="whitegrid")

# Party Classifications
# Every entry is used as a case-sensitive *substring* pattern by
# classify_populist(), so a short code such as "SD" or "FP" also matches any
# longer party label that merely contains it.
POPULIST_PARTIES = [
    # Austria
    "BZÖ", "FPÖ", "Freiheitliche Partei Österreichs", "STRONACH", "Team Stronach", "SPÖ", "Grüne",
    # Denmark
    "DF", "Dansk Folkeparti",
    # Finland & Norway (Presumed codes)
    "FP", "NY", 
    # Germany
    "AfD", "Alternative für Deutschland", "PDS/LINKE", "Die Linke", "Partei des Demokratischen Sozialismus",
    # Sweden
    "SD", "Sweden Democrats", "Sverigedemokraterna", "NYD"
]

# Orientation sets used (again via substring matching) by classify_orientation().
# NOTE(review): the long-form names "Freiheitliche Partei Österreichs",
# "Dansk Folkeparti", "Alternative für Deutschland", "Sweden Democrats" and
# "Sverigedemokraterna" appear in POPULIST_PARTIES but in neither set below, so a
# row carrying only such a name is labelled "Populist" yet gets no orientation.
# Confirm whether the data only ever contains the short codes.
LEFT_POPULIST = {'PDS/LINKE', 'SPÖ', 'Grüne', 'Die Linke', 'Partei des Demokratischen Sozialismus'}
RIGHT_POPULIST = {'AfD', 'BZÖ', 'DF', 'FP', 'FPÖ', 'NY', 'NYD', 'STRONACH', 'Team Stronach', 'SD'}


# -----------------------------------------------------------------------------
# 2. Helper Functions
# -----------------------------------------------------------------------------

def load_data(path):
    """Read the CSV at *path* into a DataFrame and report its shape.

    Args:
        path: Location of the input CSV file.

    Returns:
        The DataFrame produced by ``pd.read_csv``.

    Raises:
        FileNotFoundError: If nothing exists at *path*.
    """
    if os.path.exists(path):
        df = pd.read_csv(path)
        print(f"Data loaded. Shape: {df.shape}")
        return df
    raise FileNotFoundError(f"Input file not found at: {path}")

def classify_populist(party_name):
    """Label a party as "Populist" or "Non-Populist".

    A party counts as populist when any entry of POPULIST_PARTIES occurs
    inside *party_name*. The check is plain, case-sensitive substring
    containment, so short codes also match longer names that contain them.

    Args:
        party_name: Party label from the data; missing values are allowed.

    Returns:
        "Populist" on a match, otherwise (including for missing values)
        "Non-Populist".
    """
    if pd.isna(party_name):
        return "Non-Populist"
    for pattern in POPULIST_PARTIES:
        if pattern in party_name:
            return "Populist"
    return "Non-Populist"

def classify_orientation(party_name):
    """Return the ideological orientation of a populist party.

    Both lookups are case-sensitive substring checks of the entries in
    RIGHT_POPULIST / LEFT_POPULIST against *party_name*. The right-populist
    set is consulted first, so a name matching both sets is reported as
    "Right Populist".

    Args:
        party_name: Party label from the data; missing values are allowed.

    Returns:
        "Right Populist", "Left Populist", or None when the value is missing
        or matches neither set.
    """
    if pd.isna(party_name):
        return None

    def _matches(patterns):
        # True when at least one pattern occurs inside the party label.
        for pattern in patterns:
            if pattern in party_name:
                return True
        return False

    if _matches(RIGHT_POPULIST):
        return "Right Populist"
    if _matches(LEFT_POPULIST):
        return "Left Populist"
    return None

def get_sentiment_label(score, mode=6):
    """Map a continuous sentiment score (0-5) to a categorical label.

    The score is rounded to the nearest integer with ``np.round`` and clipped
    to the range 0..5 before the lookup.

    Args:
        score: Continuous sentiment score. Missing values (NaN / None) are
            allowed; these occur when the text was empty or not a string.
        mode: 6 for the six-level scale, 3 for the collapsed
            Negative / Neutral / Positive scale.

    Returns:
        The label string, or ``np.nan`` when *score* is missing so that the
        resulting column stays missing instead of aborting the whole run.

    Raises:
        ValueError: If *mode* is neither 3 nor 6.
    """
    if mode not in (3, 6):
        raise ValueError(f"Unsupported mode: {mode!r} (expected 3 or 6)")

    # int(nan) raises, so missing scores must be short-circuited here.
    if pd.isna(score):
        return np.nan

    level = int(np.clip(np.round(score), 0, 5))

    if mode == 6:
        six_labels = {
            0: "Negative", 1: "Mixed Negative", 2: "Neutral Negative",
            3: "Neutral Positive", 4: "Mixed Positive", 5: "Positive",
        }
        return six_labels[level]

    # mode == 3: levels 0,1 -> Negative; 2,3 -> Neutral; 4,5 -> Positive
    three_labels = {0: "Negative", 1: "Neutral", 2: "Positive"}
    return three_labels[level // 2]

def calculate_ci(df, group_cols, value_col="sentiment_score"):
    """Summarise *value_col* per group with mean, SD, n, SE and a 95% CI half-width.

    Args:
        df: Input DataFrame.
        group_cols: Column name(s) to group by.
        value_col: Numeric column to summarise.

    Returns:
        DataFrame with the grouping columns followed by ``mean``, ``sd``
        (sample standard deviation), ``n`` (non-missing count), ``se``
        (sd / sqrt(n)) and ``ci`` (normal-approximation 95% half-width).
    """
    per_group = df.groupby(group_cols)[value_col]
    stats = (
        per_group.agg(["mean", "std", "count"])
        .rename(columns={"std": "sd", "count": "n"})
        .reset_index()
    )

    # Two-sided 95% interval under the normal approximation.
    critical = norm.ppf(0.975)
    stats["se"] = stats["sd"] / np.sqrt(stats["n"])
    stats["ci"] = critical * stats["se"]
    return stats

def save_plot(fig, filename):
    """Save *fig* into FIGURE_DIR, creating the directory when necessary.

    Args:
        fig: Figure object exposing ``savefig``.
        filename: File name (not a full path) of the image inside FIGURE_DIR.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(FIGURE_DIR, exist_ok=True)
    path = os.path.join(FIGURE_DIR, filename)
    fig.savefig(path, dpi=300, bbox_inches='tight')
    print(f"Figure saved: {path}")


# -----------------------------------------------------------------------------
# 3. Model Inference Class
# -----------------------------------------------------------------------------

class SentimentAnalyzer:
    """Scores texts of arbitrary length with a sequence-classification model.

    Long texts are split into chunks that fit the model's input size; every
    chunk is scored on its own and the chunk scores are averaged.
    """

    def __init__(self, model_name, device_id, max_tokens=None):
        """Load tokenizer and model and move the model to the target device.

        Args:
            model_name: Hugging Face model identifier or local path.
            device_id: CUDA device index, or a negative number for the CPU.
            max_tokens: Maximum chunk length in tokens *including* the two
                special tokens. Defaults to the module-level MAX_TOKENS.
        """
        print(f"Loading model: {model_name} on device {device_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.device = torch.device(f"cuda:{device_id}" if device_id >= 0 else "cpu")
        self.max_tokens = MAX_TOKENS if max_tokens is None else max_tokens
        self.model.to(self.device)
        self.model.eval()

    def predict_long_text(self, text):
        """Return the mean sentiment score over all chunks of *text*.

        The text is tokenised without special tokens, cut into windows of
        ``max_tokens - 2`` tokens, and each window is wrapped in the
        tokenizer's CLS/SEP tokens (``<s>`` / ``</s>`` for XLM-R) before
        inference. Slicing an already-wrapped sequence instead would leave
        every chunk after the first without the leading token that the
        classification head reads, and could produce a trailing chunk that
        holds nothing but the end marker.

        Args:
            text: Raw text to score.

        Returns:
            Mean of the per-chunk scores, or ``np.nan`` when *text* is not a
            string, is blank, or yields no tokens.

        Raises:
            ValueError: If ``max_tokens`` leaves no room for content tokens.
        """
        if not isinstance(text, str) or not text.strip():
            return np.nan

        window = self.max_tokens - 2  # reserve room for the CLS and SEP tokens
        if window <= 0:
            raise ValueError("max_tokens must be at least 3")

        encoded = self.tokenizer(text, add_special_tokens=False, truncation=False)
        token_ids = list(encoded["input_ids"])
        if not token_ids:
            return np.nan

        head = [self.tokenizer.cls_token_id]
        tail = [self.tokenizer.sep_token_id]
        scores = []

        with torch.no_grad():
            for start in range(0, len(token_ids), window):
                wrapped = head + token_ids[start:start + window] + tail
                input_ids = torch.tensor([wrapped], dtype=torch.long, device=self.device)
                logits = self.model(input_ids=input_ids).logits.squeeze()

                if logits.ndim == 0:
                    # Regression head: the single output is the score itself.
                    scores.append(logits.item())
                else:
                    # Classification head: use the expected class index.
                    probs = torch.softmax(logits, dim=-1)
                    classes = torch.arange(
                        probs.shape[-1], device=probs.device, dtype=probs.dtype
                    )
                    scores.append(torch.sum(probs * classes).item())

        return np.mean(scores) if scores else np.nan


# -----------------------------------------------------------------------------
# 4. Main Execution
# -----------------------------------------------------------------------------

def run_analysis():
    """Run loading, filtering, sentiment scoring and labelling; return the result.

    The input file is loaded and reduced to rows whose 'sum_relevant_topic'
    exceeds 50. If OUTPUT_FILE already exists it is loaded in place of the
    filtered data and inference is skipped; otherwise every 'raw_text' is
    scored, categorical labels are derived, and the result is written to
    OUTPUT_FILE. In both cases the populist / orientation columns are added
    afterwards and 'sentiment_score' is clipped to [0, 5].

    Returns:
        The processed DataFrame used by the plotting routines.
    """
    # --- Step 1: load the corpus and report topic-relevance diagnostics ---
    df = load_data(INPUT_FILE)
    relevance = df["sum_relevant_topic"]

    print("Summary of 'sum_relevant_topic':")
    print(relevance.describe())

    count_60 = (relevance >= 60).sum()
    print(f"Observations >= 60: {count_60} ({count_60 / len(df):.2%})")

    # Keep only rows above the relevance threshold of 50.
    filtered_df = df[relevance > 50].copy()
    print(f"Filtered dataset size: {filtered_df.shape[0]}")

    # --- Step 2: score the texts unless a cached result file is present ---
    if not os.path.exists(OUTPUT_FILE):
        print("Starting Sentiment Inference...")
        analyzer = SentimentAnalyzer(MODEL_NAME, DEVICE)
        scores = filtered_df["raw_text"].progress_apply(analyzer.predict_long_text)
        filtered_df["sentiment_score"] = scores

        # Derive the six-level and three-level categorical labels.
        filtered_df["sentiment_6cat"] = scores.apply(lambda s: get_sentiment_label(s, 6))
        filtered_df["sentiment_3cat"] = scores.apply(lambda s: get_sentiment_label(s, 3))

        filtered_df.to_csv(OUTPUT_FILE, index=False)
        print(f"Inference complete. Results saved to {OUTPUT_FILE}")
    else:
        print("Existing results found. Loading...")
        filtered_df = pd.read_csv(OUTPUT_FILE)

    # --- Step 3: party classification and score clipping for the figures ---
    parties = filtered_df["party"]
    filtered_df["populist"] = parties.apply(classify_populist)
    filtered_df["populist_orientation"] = parties.apply(classify_orientation)
    filtered_df["sentiment_score"] = filtered_df["sentiment_score"].clip(0, 5)

    return filtered_df

# -----------------------------------------------------------------------------
# 5. Visualization Routines
# -----------------------------------------------------------------------------

def plot_populist_vs_non_by_country(df):
    """Figure 1: mean sentiment of populist vs. non-populist parties, one panel per country.

    Bars show the group mean with a 95% CI error bar; the figure is saved as
    'fig1_sentiment_country_populist.png' and then shown.
    """
    plot_df = calculate_ci(df, ["country", "populist"])
    palette = {"Populist": "firebrick", "Non-Populist": "silver"}

    g = sns.FacetGrid(plot_df, col="country", sharey=True, height=5, aspect=0.8)

    def draw_group_bars(data, **kwargs):
        # Called once per facet with that country's rows.
        if data.empty:
            return
        axis = plt.gca()
        for group, mean_value, half_width in zip(data["populist"], data["mean"], data["ci"]):
            axis.bar(
                x=group,
                height=mean_value,
                yerr=half_width,
                capsize=5,
                color=palette.get(group, "grey"),
                width=0.6,
                edgecolor="black",
            )

    g.map_dataframe(draw_group_bars)
    g.set_titles("{col_name}")
    g.set_axis_labels("", "Average Sentiment Score")
    plt.ylim(0, 5)  # full score scale

    save_plot(plt.gcf(), "fig1_sentiment_country_populist.png")
    plt.show()

def plot_kde_distribution(df):
    """Figure 2: kernel density estimate of sentiment scores by populist status.

    Densities are normalised per group and restricted to [0, 5]; the figure
    is saved as 'fig2_kde_populist_distribution.png' and then shown.
    """
    palette = {"Populist": "firebrick", "Non-Populist": "silver"}
    fig, axis = plt.subplots(figsize=(10, 6))

    sns.kdeplot(
        data=df,
        x="sentiment_score",
        hue="populist",
        fill=True,
        common_norm=False,
        palette=palette,
        alpha=0.5,
        clip=(0, 5),
        ax=axis,
    )

    # Drop the automatic legend title (the hue column name).
    axis.get_legend().set_title(None)
    axis.set_xlabel("Sentiment Score")
    axis.set_ylabel("Density")
    axis.set_title("Distribution of Sentiment Scores")

    save_plot(fig, "fig2_kde_populist_distribution.png")
    plt.show()

def plot_all_parties_by_country(df):
    """Figure 3: mean sentiment of every party, one horizontal-bar panel per country.

    Bars are sorted by mean within each panel, coloured by populist status
    and carry 95% CI error bars; the figure is saved as
    'fig3_all_parties_sentiment.png' and then shown.
    """
    color_map = {"Populist": "firebrick", "Non-Populist": "silver"}
    summary_df = calculate_ci(df, ["country", "party", "populist"])

    g = sns.FacetGrid(
        summary_df,
        col="country",
        col_wrap=3,
        sharex=True,
        sharey=False,
        height=4,
        aspect=1.2,
    )

    def draw_party_bars(data, **kwargs):
        # Called once per facet with that country's rows.
        axis = plt.gca()
        ordered = data.sort_values("mean")
        axis.barh(
            y=ordered["party"],
            width=ordered["mean"],
            color=ordered["populist"].map(color_map),
            edgecolor="black",
            height=0.6,
        )
        # Confidence intervals drawn on top of the bars.
        axis.errorbar(
            x=ordered["mean"],
            y=ordered["party"],
            xerr=ordered["ci"],
            fmt='none',
            ecolor='black',
            capsize=2,
        )
        axis.set_xlim(0, 5)

    g.map_dataframe(draw_party_bars)
    g.set_titles("{col_name}")
    g.set_xlabels("Avg. Sentiment Score")

    save_plot(plt.gcf(), "fig3_all_parties_sentiment.png")
    plt.show()

def plot_left_vs_right_populist(df):
    """Figure 4: mean sentiment of left- vs. right-populist parties.

    Only rows with a non-missing 'populist_orientation' are used. Each bar is
    one party/country pair, sorted by mean, coloured by orientation and
    carrying a 95% CI error bar; the figure is saved as
    'fig4_left_vs_right_populist.png' and then shown.
    """
    color_map = {"Left Populist": "steelblue", "Right Populist": "firebrick"}

    has_orientation = df["populist_orientation"].notnull()
    pop_df = df[has_orientation].copy()

    summary_df = calculate_ci(pop_df, ["party", "country", "populist_orientation"])
    summary_df["label"] = summary_df["party"] + " (" + summary_df["country"] + ")"
    summary_df = summary_df.sort_values("mean")

    fig, axis = plt.subplots(figsize=(10, 8))

    axis.barh(
        y=summary_df["label"],
        width=summary_df["mean"],
        color=summary_df["populist_orientation"].map(color_map),
        edgecolor="black",
    )
    axis.errorbar(
        x=summary_df["mean"],
        y=summary_df["label"],
        xerr=summary_df["ci"],
        fmt='none',
        ecolor='black',
        capsize=3,
    )

    axis.set_xlabel("Average Sentiment Score")
    axis.set_title("Sentiment Scores: Left vs. Right Populist Parties")
    axis.set_xlim(0, 5)
    axis.grid(axis='x', linestyle='--', alpha=0.7)

    save_plot(fig, "fig4_left_vs_right_populist.png")
    plt.show()


# -----------------------------------------------------------------------------
# Main Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    print("=== Choi & Jung 2025 PSQ R&R Replication Start ===")

    # Build the scored dataset (or load the cached one).
    final_df = run_analysis()

    # Produce the four figures in publication order.
    print("\nGenerating Figures...")
    for make_figure in (
        plot_populist_vs_non_by_country,
        plot_kde_distribution,
        plot_all_parties_by_country,
        plot_left_vs_right_populist,
    ):
        make_figure(final_df)

    print("\n=== Replication Complete ===")